import os
import glob

import torch
from torch.utils.data import Dataset
from scipy import signal
from scipy.io import wavfile
import cv2
from PIL import Image
import numpy as np


class Synth90kDataset(Dataset):
    """Map-style dataset yielding (image, target, target_length) samples.

    Each sample pairs an image file on disk with its text. The image is
    loaded as grey-scale, resized to ``img_width`` x ``img_height`` and
    scaled to the range [-1, 1]; the text is encoded character by
    character through ``CHAR2LABEL``.

    Args:
        paths: Indexable collection of image file paths.
        texts: Indexable collection of strings, aligned with ``paths``.
        img_height: Height in pixels every image is resized to.
        img_width: Width in pixels every image is resized to.
        CHARS: Iterable of the characters that may occur in ``texts``.
            Required in practice, despite the ``None`` default.

    Raises:
        TypeError: If ``CHARS`` is ``None``.
    """

    def __init__(self, paths, texts, img_height=32, img_width=100, CHARS=None):
        if CHARS is None:
            # enumerate(None) would raise a TypeError anyway; keep the
            # exception type but say what is actually wrong.
            raise TypeError('CHARS must be an iterable of characters, got None')

        self.paths = paths
        self.texts = texts
        self.img_height = img_height
        self.img_width = img_width

        self.CHARS = CHARS
        # Labels start at 1, leaving 0 unused here
        # (presumably reserved for the CTC blank -- confirm against the model).
        self.CHAR2LABEL = {char: i + 1 for i, char in enumerate(CHARS)}
        self.LABEL2CHAR = {label: char for char, label in self.CHAR2LABEL.items()}

    def __len__(self):
        """Return the number of samples (one per path)."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load and encode the sample at ``index``.

        If the image cannot be read, the following samples are tried in
        turn, wrapping around to the start of the dataset, so a corrupted
        last sample no longer runs past the end.

        Returns:
            Tuple ``(image, target, target_length)``: a float tensor of
            shape ``(1, img_height, img_width)`` with values in [-1, 1],
            a long tensor holding the label of each character, and a
            one-element long tensor holding the number of labels.

        Raises:
            IndexError: If ``index`` is out of range.
            KeyError: If the text contains a character not in ``CHARS``.
            RuntimeError: If no image in the dataset could be read.
        """
        num_samples = len(self.paths)

        # Try each sample at most once so a fully corrupted dataset fails
        # loudly instead of recursing or looping forever.
        for _ in range(max(num_samples, 1)):
            path = self.paths[index]
            try:
                # The context manager closes the underlying file as soon as
                # the converted copy has been made.
                with Image.open(path) as source:
                    image = source.convert('L')  # grey-scale
            except IOError:
                print('Corrupted image for %d' % index)
                index = (index + 1) % num_samples
            else:
                break
        else:
            raise RuntimeError('No readable image found in the dataset')

        image = image.resize((self.img_width, self.img_height), resample=Image.BILINEAR)
        image = np.array(image)
        image = image.reshape((1, self.img_height, self.img_width))
        image = (image / 127.5) - 1.0  # map pixel values 0..255 to -1..1

        image = torch.FloatTensor(image)
        # ``index`` now refers to the sample whose image was actually loaded,
        # so the text always matches the returned image.
        text = self.texts[index]
        target = [self.CHAR2LABEL[c] for c in text]
        target_length = [len(target)]

        target = torch.LongTensor(target)
        target_length = torch.LongTensor(target_length)
        return image, target, target_length


def synth90k_collate_fn(batch):
    """Merge a list of (image, target, target_length) samples into a batch.

    Images are stacked along a new leading dimension. Targets and target
    lengths are concatenated along dimension 0, so the targets of all
    samples end up in one flat tensor and the lengths tensor tells where
    each sample's labels begin and end.

    Returns:
        Tuple ``(images, targets, target_lengths)`` of tensors.
    """
    # Transpose the list of per-sample triples into three per-field tuples.
    image_seq, label_seq, length_seq = zip(*batch)
    return (
        torch.stack(image_seq, dim=0),
        torch.cat(label_seq, dim=0),
        torch.cat(length_seq, dim=0),
    )
